feat: batch KV block copies via cudaMemcpyBatchAsync in fs connector#607

Merged
kfirtoledo merged 2 commits into llm-d:main from kfirtoledo:batch-memcpy
Jun 4, 2026

Conversation

@kfirtoledo

Collaborator

Summary

Replace the per-(block, layer) cudaMemcpyAsync loop in TensorCopier with a single cudaMemcpyBatchAsync (CUDA 12.8+) submission. Submits all descriptors in one driver call, removing per-call dispatch overhead.

  • Enabled by default; toggle off via USE_BATCH_MEMCPY_READ=0 / USE_BATCH_MEMCPY_WRITE=0.
  • Per-call DMA loop kept as fallback for older CUDA toolkits and A/B debugging.
  • srcAccessOrder=ANY set on cudaMemcpyAttributes (matches vLLM's simple_kv_offload/cuda_mem_ops.py).
  • #if CUDA_VERSION handles the failIdx out-param that CUDA 13 dropped.
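The batching idea above can be sketched on the host side without any CUDA calls: instead of issuing one copy per (block, layer) pair, the pairs are first flattened into parallel arrays of destinations, sources, and sizes. The sketch below is an illustration only; `LayerTensor`, `CopyBatch`, `build_write_batch`, and the staging layout (blocks back to back, layers contiguous within a block) are assumptions, not names or layouts taken from `TensorCopier`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical view of one layer's KV tensor (illustrative, not from the PR).
struct LayerTensor {
  unsigned char* base;      // start of this layer's KV tensor
  std::size_t block_bytes;  // bytes one block occupies in this layer
};

// Parallel arrays, one entry per (block, layer) copy.
struct CopyBatch {
  std::vector<void*> dsts;
  std::vector<void*> srcs;
  std::vector<std::size_t> sizes;
};

// Build one descriptor per (block, layer) pair for a tensor -> staging write.
// Assumed staging layout: blocks back to back, layers contiguous per block.
CopyBatch build_write_batch(const std::vector<LayerTensor>& layers,
                            const std::vector<std::size_t>& block_ids,
                            unsigned char* staging) {
  assert(staging != nullptr);
  CopyBatch batch;
  const std::size_t count = layers.size() * block_ids.size();
  batch.dsts.reserve(count);
  batch.srcs.reserve(count);
  batch.sizes.reserve(count);

  std::size_t offset = 0;  // running write position in the staging buffer
  for (std::size_t block_id : block_ids) {
    for (const LayerTensor& layer : layers) {
      batch.srcs.push_back(layer.base + block_id * layer.block_bytes);
      batch.dsts.push_back(staging + offset);
      batch.sizes.push_back(layer.block_bytes);
      offset += layer.block_bytes;
    }
  }
  return batch;
}
```

In the batched path, `batch.dsts.data()`, `batch.srcs.data()`, `batch.sizes.data()`, and `batch.sizes.size()` would correspond to the destination, source, size, and count arguments of a single `cudaMemcpyBatchAsync` submission. The remaining arguments (attributes and, before CUDA 13, the `failIdx` out-param) differ by toolkit version and are deliberately left out of this sketch.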

Measured impact (128k tokens, TP=4, --block-size 512)

| Workload | no-BATCH | BATCH | speedup |
|---|---|---|---|
| gpt-oss-20b (HMA) cold | 3.35s | 1.77s | 1.9x |
| gpt-oss-20b (HMA) hot | 0.58s | 0.30s | 1.9x |
| gpt-oss-120b (HMA) cold | 4.95s | 2.51s | 2.0x |
| gpt-oss-120b (HMA) hot | 0.83s | 0.34s | 2.4x |
| Llama-3.1-8B hot | 0.45s | 0.43s | neutral |
| Llama-3.1-70B hot | 0.84s | 0.85s | neutral |

Big wins on HMA models where per-layer DMAs are small; neutral on Llama/no-HMA where each per-call copy is already large enough that driver dispatch is amortized.
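One way to reason about this split is a toy cost model, sketched below. It is an assumption made for illustration, not a measurement from this PR: each per-call copy costs the larger of a fixed dispatch time and its transfer time (dispatch overlaps with the previous DMA), while the batched path pays dispatch once. `CopyPlan`, the function names, and any numbers plugged in are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical workload description: many equally sized copies.
struct CopyPlan {
  std::size_t num_copies;
  double bytes_per_copy;
};

// Per-call loop: each copy is bounded by dispatch or transfer, whichever is
// slower (toy assumption: dispatch overlaps with the previous transfer).
double per_call_seconds(const CopyPlan& plan, double dispatch_s, double bytes_per_s) {
  assert(bytes_per_s > 0.0);
  const double transfer_s = plan.bytes_per_copy / bytes_per_s;
  return static_cast<double>(plan.num_copies) * std::max(dispatch_s, transfer_s);
}

// Batched submission: one dispatch, then all transfers back to back.
double batched_seconds(const CopyPlan& plan, double dispatch_s, double bytes_per_s) {
  assert(bytes_per_s > 0.0);
  const double transfer_s = plan.bytes_per_copy / bytes_per_s;
  return dispatch_s + static_cast<double>(plan.num_copies) * transfer_s;
}

double speedup(const CopyPlan& plan, double dispatch_s, double bytes_per_s) {
  return per_call_seconds(plan, dispatch_s, bytes_per_s) /
         batched_seconds(plan, dispatch_s, bytes_per_s);
}
```

Under this model, plans with many small copies are dispatch-bound and gain from batching, while plans whose individual copies take longer than a dispatch come out at roughly 1x, which is the qualitative pattern in the table. The model ignores everything else in the offload path (file I/O, synchronization), so it should not be expected to reproduce the measured ratios.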

Test plan

  • `make test` — 30 passed, 3 skipped (all on this branch, batched path active by default).
  • Manual storage roundtrip on Llama-3.1-8B / 70B + gpt-oss-20b/120b with batch enabled and disabled (see table).

@github-actions github-actions Bot added the size/XL Denotes a PR that changes 500-999 lines, ignoring generated files. label May 26, 2026
@github-actions github-actions Bot added size/L Denotes a PR that changes 100-499 lines, ignoring generated files. and removed size/XL Denotes a PR that changes 500-999 lines, ignoring generated files. labels May 26, 2026

@Etelis Etelis left a comment

Add a fallback for CUDA_VERSION < 12080 (CUDA 12.8); otherwise LGTM.

Comment thread kv_connectors/llmd_fs_backend/csrc/storage/tensor_copier.cu Outdated
// Batched DMA path: one cudaMemcpyBatchAsync covers all per-(block, layer)
// copies for the blocks in this file (num_blocks * num_tensors).
// The batch executes in stream order; ordering within the batch is unspecified.
void TensorCopier::copy_blocks_via_batch_memcpy(

I'm not sure it's relevant, but HIP (AMD) also supports batch memcpy out of the box, and might be worth adding as well.

@github-actions

github-actions Bot commented Jun 2, 2026

Unsigned commits detected! Please sign your commits.

For instructions on how to set up GPG/SSH signing and verify your commits, please see GitHub Documentation.

@Etelis Etelis left a comment

lgtm

Submit all per-(block, layer) copies in one driver call instead of N
cudaMemcpyAsync calls. Enabled by default; toggle off with
USE_BATCH_MEMCPY_READ / USE_BATCH_MEMCPY_WRITE=0. Requires CUDA 12.8+.

Speeds up KV-cache offload writes/reads when per-layer DMA sizes are
small enough that driver dispatch dominates.

Signed-off-by: Kfir Toledo <kfir.toledo@ibm.com>

cudaMemcpyBatchAsync was introduced in CUDA 12.8; guard the batch
path with #if CUDA_VERSION >= 12080 and route to the per-call
cudaMemcpyAsync loop below that. Default USE_BATCH_MEMCPY_* off on
older toolchains so the env knob still makes sense.

Also drop thread_local on the attrs/attrs_idx inputs (never mutated,
no per-thread duplication needed) and move the copy_blocks dispatcher
below the helpers it dispatches to.

Signed-off-by: Kfir Toledo <kfir.toledo@ibm.com>
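
The version guard and environment defaults described in the second commit message could look roughly like the sketch below. It is host-only and makes several assumptions: `kBatchMemcpyAvailable`, `env_flag`, and `use_batch_memcpy` are hypothetical names, and treating any value other than `0` as "on" is a guess at the parsing rule, not the connector's confirmed behavior.

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>

// CUDA_VERSION normally comes from <cuda.h>; an undefined macro evaluates
// as 0 inside #if, so a build without CUDA takes the fallback branch.
#if CUDA_VERSION >= 12080
constexpr bool kBatchMemcpyAvailable = true;   // CUDA 12.8 or newer
#else
constexpr bool kBatchMemcpyAvailable = false;  // older toolkit: per-call loop
#endif

// Read a boolean environment knob. Unset or empty keeps the default;
// "0" disables; anything else enables (assumed parsing rule).
bool env_flag(const char* name, bool default_value) {
  assert(name != nullptr);
  const char* value = std::getenv(name);
  if (value == nullptr || *value == '\0') {
    return default_value;
  }
  return std::strcmp(value, "0") != 0;
}

// The batched path is used only when it was compiled in AND not disabled.
// The default follows toolkit support, so the knob defaults to off on
// older toolchains.
bool use_batch_memcpy(const char* name) {
  return kBatchMemcpyAvailable && env_flag(name, kBatchMemcpyAvailable);
}
```

A caller would evaluate `use_batch_memcpy("USE_BATCH_MEMCPY_READ")` or `use_batch_memcpy("USE_BATCH_MEMCPY_WRITE")` once and route to either the batched helper or the per-call `cudaMemcpyAsync` loop.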

@dannyharnik dannyharnik left a comment

Collaborator

LGTM
/approve

@kfirtoledo kfirtoledo merged commit 59aa98f into llm-d:main Jun 4, 2026
11 checks passed
Labels

size/L Denotes a PR that changes 100-499 lines, ignoring generated files.

3 participants